Use complex tensors in phase_vocoder #758

Closed
wants to merge 14 commits

Conversation

@anjali411 anjali411 commented Jul 1, 2020

This PR updates phase_vocoder to use complex tensors. A local boolean flag USE_COMPLEX is added in phase_vocoder to detect whether the input is complex. If the input is complex, complex-dtype tensors are used throughout the implementation and a complex-dtype tensor is returned.

Test Plan: Adds JIT tests and new tests for the complex dtype in separate files.

Documentation: reflects both the old API behavior, which used real tensors with a (..., complex=2) dimension, and the new API behavior, which uses complex-dtype tensors.

Deprecation Warnings will be added at the end (before the release) when all functions support complex tensors.

In a follow-up PR, torch.polar should be added to construct complex tensors from abs and angle in torch.functional.phase_vocoder.
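
The detection described above can be illustrated with a rough sketch (the helper name is hypothetical; this is not the actual torchaudio code):

```python
import torch

def detect_complex(complex_specgrams: torch.Tensor) -> bool:
    # Mirrors the USE_COMPLEX flag: True when the input already has a
    # complex dtype, False for the legacy (..., complex=2) real layout.
    return complex_specgrams.is_complex()

legacy = torch.randn(2, 1025, 300, 2)                   # old representation
modern = torch.randn(2, 1025, 300, dtype=torch.cfloat)  # new representation
```

With this, the same entry point can serve both representations until the legacy one is deprecated.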

@anjali411 anjali411 requested a review from vincentqb July 1, 2020 20:28
torchaudio/functional.py (outdated review thread; resolved)
>>> # (channel, freq, time, complex=2)
>>> complex_specgrams = torch.randn(2, freq, 300, 2)
>>> # (channel, freq, time)
>>> complex_specgrams = torch.randn(2, freq, 300, dtype=torch.cfloat)
Contributor:

this is neat!

torchaudio/functional.py (outdated review thread; resolved)
@vincentqb (Contributor):

Are you planning to change all the implementations of complex numbers in torchaudio? We could add this to the release if it's ready and all the references are changed. Anything to be aware of? In particular, we are not testing autograd, and we have not made any commitments yet. Would this support autograd?

@mthrok mthrok changed the title from "Use complex tensors in phase_vocoder" to "[PoC] Use complex tensors in phase_vocoder" on Jul 1, 2020
@vincentqb vincentqb marked this pull request as draft July 1, 2020 20:49
@anjali411 (Author):

Are you planning to change all the implementations of complex numbers in torchaudio? We could add this to the release if it's ready and all the references are changed. Anything to be aware of? In particular, we are not testing autograd, and we have not made any commitments yet. Would this support autograd?

synced offline:

  1. We don't plan to include these changes in the upcoming release.
  2. We should return a complex tensor when the input is a complex tensor, and a real tensor when the input is a real tensor.
  3. There should be a deprecation warning if the user inputs a real tensor (..., 2).
  4. Ensure jit consistency tests run fine.
  5. Add tests for autograd?

@anjali411 anjali411 force-pushed the phase_vocoder branch 2 times, most recently from 89ba39b to 1ebd7e7 on July 2, 2020 00:03

codecov bot commented Jul 2, 2020

Codecov Report

❗ No coverage uploaded for pull request base (master@4a8610f).
The diff coverage is 100.00%.


@@            Coverage Diff            @@
##             master     #758   +/-   ##
=========================================
  Coverage          ?   89.17%           
=========================================
  Files             ?       32           
  Lines             ?     2512           
  Branches          ?        0           
=========================================
  Hits              ?     2240           
  Misses            ?      272           
  Partials          ?        0           
Impacted Files Coverage Δ
torchaudio/transforms.py 95.75% <ø> (ø)
torchaudio/functional.py 95.19% <100.00%> (ø)
torchaudio/_internal/misc_ops.py 76.47% <0.00%> (ø)
torchaudio/datasets/yesno.py 78.26% <0.00%> (ø)
torchaudio/backend/utils.py 87.80% <0.00%> (ø)
torchaudio/datasets/utils.py 50.00% <0.00%> (ø)
torchaudio/backend/common.py 100.00% <0.00%> (ø)
torchaudio/compliance/__init__.py 100.00% <0.00%> (ø)
torchaudio/datasets/gtzan.py 75.92% <0.00%> (ø)
... and 24 more

Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 4a8610f...f3a6ac9.


anjali411 commented Jul 2, 2020

@vincentqb @mthrok The CI for this PR looks green (not sure how the codecov report is generated).

On a further look at the torchaudio repo, it seems that almost all functions in functional.py will have to be modified to work correctly with complex numbers. These are some possible reasons why the functions would have to be modified:

  1. The function expects as input, or returns, a real representation (..., 2) instead of a complex tensor. In this case, the implementation will also have to be modified accordingly.
  2. The function doesn't have any (..., 2) assumptions but uses functions like torch.clamp whose behavior is not intuitively defined for complex numbers and is currently disabled in torch.
  3. The functions internally use workaround functions for angle, norm, etc.

Since the functions in functional.py are called in many places in the code, there would be a lot more sites where modifications might have to be made.

TorchScript

  1. tensor.real and tensor.imag don't currently work with TorchScript. A current workaround is to use the functions torch.real and torch.imag instead, but it would be useful to have the tensor attributes working, especially for the case where we would need the setter.
  2. TorchScript consistency tests first explicitly convert the dtype of the tensors to float/double, but this will no longer be correct to test since, after the above modifications, some functions will expect complex tensors as input.
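
A minimal sketch of the workaround mentioned in point 1, assuming a PyTorch version where TorchScript accepts complex tensors:

```python
from typing import Tuple

import torch

@torch.jit.script
def split_parts(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # torch.real / torch.imag script fine even where the .real / .imag
    # tensor attributes are not yet supported in TorchScript.
    return torch.real(x), torch.imag(x)

re, im = split_parts(torch.randn(3, dtype=torch.cfloat))
```

The attribute setter case (e.g. `x.imag = ...`) has no such functional equivalent, which is why attribute support would still be useful.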

Question for you:

Should we add a warning in this release that torchaudio will switch to complex tensors in the next release so that we don't have to add these if-else blocks at every site:

if input.is_complex():
    # complex-dtype code path
    ...
else:
    # legacy (..., complex=2) code path
    ...

Things to figure out:

Autograd requirements for torchaudio to be able to adopt complex numbers.


vincentqb commented Jul 6, 2020

torch.clamp whose behavior is not intuitively defined for complex numbers and is currently disabled in torch.

How about truncating the norm of a complex number in polar coordinates? i.e., (1) convert to polar coordinates, (2) clamp the norm, (3) convert back to Cartesian coordinates. This has the advantage of reducing to clamping on [-x, x] for real numbers. I guess this wouldn't work as well for an asymmetric interval clamp. Other cases to think about?
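
The polar-coordinate clamp sketched above could look like this (using torch.polar, which did not exist at the time of this discussion but later landed):

```python
import torch

def clamp_complex(x: torch.Tensor, max_norm: float) -> torch.Tensor:
    # (1) go to polar coordinates, (2) clamp the magnitude,
    # (3) convert back to Cartesian form.
    mag = x.abs().clamp(max=max_norm)
    return torch.polar(mag, x.angle())

x = torch.tensor([3.0 + 4.0j, 0.3 - 0.4j])
y = clamp_complex(x, 1.0)  # magnitudes become [1.0, 0.5]
```

The phases are untouched, so for real inputs (angle 0 or pi) this reduces to a symmetric clamp on the magnitude, matching the intuition above.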

phase_acc = torch.cumsum(phase, -1)

mag = alphas * norm_1 + (1 - alphas) * norm_0

real_stretch = mag * torch.cos(phase_acc)
imag_stretch = mag * torch.sin(phase_acc)

complex_specgrams_stretch = torch.stack([real_stretch, imag_stretch], dim=-1)
complex_specgrams_stretch = torch.view_as_complex(torch.stack([real_stretch, imag_stretch], dim=-1))

this seems pretty clunky, is there not a better way of doing it?

@anjali411 (Author) Jul 7, 2020:

agreed! After this PR: pytorch/pytorch#39617, we can just rewrite this as

torch.complex_polar(mag, phase_acc)

instead of:

real_stretch = mag * torch.cos(phase_acc)
imag_stretch = mag * torch.sin(phase_acc)
complex_specgrams_stretch = torch.view_as_complex(torch.stack([real_stretch, imag_stretch], dim=-1))
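
For reference, the operator eventually landed as torch.polar rather than torch.complex_polar; the equivalence of the two constructions can be sketched as:

```python
import torch

mag = torch.rand(2, 5, 7)
phase_acc = torch.rand(2, 5, 7)

# Explicit construction via stacked real/imaginary parts
real_stretch = mag * torch.cos(phase_acc)
imag_stretch = mag * torch.sin(phase_acc)
via_stack = torch.view_as_complex(
    torch.stack([real_stretch, imag_stretch], dim=-1))

# One-liner with the polar constructor
via_polar = torch.polar(mag, phase_acc)
```

Both produce the same complex tensor; the polar form just avoids the intermediate stack.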

@@ -458,68 +458,67 @@ def phase_vocoder(
factor of ``rate``.

Args:
complex_specgrams (Tensor): Dimension of `(..., freq, time, complex=2)`
Contributor:

How about something like this?

        complex_specgrams (Tensor): Either a real tensor of dimension of `(..., freq, time, complex=2)`
            or a tensor of dimension `(..., freq, time)` with complex dtype.

We were using "complex tensor" to mean (..., complex=2). This is now ambiguous. What expression do you recommend to refer to a tensor of complex dtype? "tensor with a complex dtype"?

@anjali411 (Author) Aug 6, 2020:

Yeah, "tensor with a complex dtype" sounds good. However, this "or" way of documenting could be problematic in cases where the function takes more than one complex tensor. Perhaps in those cases we can add a note stating that either all inputs should be real tensors or all inputs should be of complex dtype.

I think it might be nicer to add a separate example with complex-dtype tensors so that it's also clear that the returned output would be complex (if applicable), especially since we are planning to switch to complex-dtype tensors in the release after the upcoming one.

Contributor:

I do find this "or" way of discussing this a little cumbersome, and I agree it will get long if many tensors are involved. We could add a note in each, and just define the args/returns with the complex dtype. We can still keep the example for clarity.

"""
    We are migrating to complex-dtype tensors. For backward compatibility reasons,
    this function still supports the legacy convention of ending with a dimension of 2
    to represent a complex tensor.

    Args:
        complex_specgrams (Tensor): A tensor of dimension `(..., freq, time)` with complex dtype.
        rate (float): Speed-up factor
        phase_advance (Tensor): Expected phase advance in each bin. Dimension of (freq, 1)
    Returns:
        Tensor: Complex Specgrams Stretch with dimension of `(..., freq, ceil(time/rate))`
            with a complex dtype.

    Example

    Example - Legacy
"""

Thoughts?

Contributor:

P.S. Good suggestion below for example naming

rate (float): Speed-up factor
phase_advance (Tensor): Expected phase advance in each bin. Dimension of (freq, 1)

Returns:
Tensor: Complex Specgrams Stretch with dimension of `(..., freq, ceil(time/rate), complex=2)`
Tensor: Complex Specgrams Stretch with dimension of `(..., freq, ceil(time/rate))`
Contributor:

based on the comment above, this would become:

        Tensor: Complex Specgrams Stretch, represented either as a real tensor with dimension of
            `(..., freq, ceil(time/rate), complex=2)` or a tensor of dimension `(..., freq, ceil(time/rate))` with complex dtype.

thoughts?

>>> rate = 1.3 # Speed up by 30%
>>> phase_advance = torch.linspace(
>>> 0, math.pi * hop_length, freq)[..., None]
>>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
>>> x.shape # with 231 == ceil(300 / 1.3)
torch.Size([2, 1025, 231, 2])
torch.Size([2, 1025, 231])
@vincentqb (Contributor) Aug 6, 2020:

We might not need to change the example. We could add a second example, or a comment next to each `, 2`. Other ideas?

@anjali411 (Author) Aug 6, 2020:

I think we should have an example with tensors of complex dtype so that users know how to deal with complex tensors. This is an option:

(Old API) Example:
....

(New API) Example:
...

what do you think?

@vincentqb (Contributor) Aug 7, 2020:

Good suggestion :)

How about standardizing on this?

    Example - New API (using tensors with complex dtype)
    Example - Old API (using tensors with (..., complex=2))

@vincentqb (Contributor):

For deprecations to switch between (real, 2) and complex-dtype, the options are:

  1. Detect dtype automatically, and warn if user uses real tensors (and maybe have a local keyword argument to suppress the warning).
  2. Have a local keyword argument to switch the behavior, with default None that raises a warning.
  3. Have a global flag to switch the behavior, with default None that raises a warning.
  4. Have separate functionals/transforms with complex dtype.

I'm leaning toward 1, without a warning until we are ready for deprecation. At that point, we'll just tell users that to suppress the warning they need to convert to the complex dtype. Thoughts?

@vincentqb (Contributor):

For testing, you can copy the existing tests into a new test file prefixed with complex_ to make it clear the tests are almost the same, but for complex.

@anjali411 (Author):

For deprecations to switch between (real, 2) and complex-dtype, the options are:

  1. Detect dtype automatically, and warn if user uses real tensors (and maybe have a local keyword argument to suppress the warning).
  2. Have a local keyword argument to switch the behavior, with default None that raises a warning.
  3. Have a global flag to switch the behavior, with default None that raises a warning.
  4. Have separate functionals/transforms with complex dtype.

I'm leaning toward 1, without a warning until we are ready for deprecation. At that point, we'll just tell users that to suppress the warning they need to convert to the complex dtype. Thoughts?

Yeah, I would also agree that having local flags for each function makes sense, since there will be JIT issues with a global flag. And once fully migrated, we should generate a warning every time a user passes a (..., complex=2) real tensor.


mthrok commented Aug 10, 2020

For testing, you can copy existing test into a new test file prefixed with complex_ to make it clear the test is almost the same but for complex.

@vincentqb @anjali411

I do not see an advantage or necessity in putting these test suites in separate files, unless complex types are not supported in fbcode (they are). Especially because this is not a new function, module, or test category. The test suite class is new due to the different dtype, which is distinction enough, so putting the tests in the same existing files makes more sense.


anjali411 commented Aug 10, 2020

For testing, you can copy the existing tests into a new test file prefixed with complex_ to make it clear the tests are almost the same, but for complex.

@vincentqb @anjali411

I do not see an advantage or necessity in putting these test suites in separate files, unless complex types are not supported in fbcode (they are). Especially because this is not a new function, module, or test category. The test suite class is new due to the different dtype, which is distinction enough, so putting the tests in the same existing files makes more sense.

I am fine either way. We can also add a new class in the same file! @mthrok, will the tests in the newly added files not run in fbcode?


mthrok commented Aug 10, 2020

I am fine either way. We can also add a new class in the same file! @mthrok, will the tests in the newly added files not run in fbcode?

@anjali411 They do, if we add the definition for the new file. But to me, it only makes sense to separate them if fbcode cannot run the complex type (and I bet it can, because PyTorch there is always the latest version). Adding the tests to an existing file would just start running them automatically.


Example
Example - old API
Contributor:

From the discussion above, it is not clear which API is "old" and which one is "new". It is more direct to say "with real tensor input" and "with complex tensor input".

Author:

Yeah, I agree that's clearer! Updated.


norm_0 = torch.norm(complex_specgrams_0, p=2, dim=-1)
norm_1 = torch.norm(complex_specgrams_1, p=2, dim=-1)
if use_complex:
Contributor:

As an alternative to duplicating all the logic here, you could have instead taken the real tensor, viewed it as complex, and then used the complex codepath (viewing it back as real in the end). Something to consider?
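
The reviewer's suggestion can be sketched with a placeholder operation standing in for the real phase-vocoder body:

```python
import torch

def process(x: torch.Tensor) -> torch.Tensor:
    # Single complex code path: view a legacy real (..., 2) input as
    # complex, run the complex implementation, then view it back.
    was_real = not x.is_complex()
    if was_real:
        x = torch.view_as_complex(x.contiguous())
    out = x * torch.exp(torch.tensor(0.1j))  # placeholder for the real work
    if was_real:
        out = torch.view_as_real(out)
    return out
```

Only one branch carries the actual logic; the real-tensor path becomes a thin view conversion at the boundaries.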

Author:

I think that's a possibility; however, the goal is to be able to remove the code in the `if not use_complex` branch after a deprecation cycle and just use the code in the other branch (which has similar logic, though with some substantial differences, e.g., the padding logic).

Contributor:

Well, in my suggestion you'd delete the real code immediately :) Anyway, this is NBD.

@anjali411 (Author) Aug 12, 2020:

Oh sorry, I misread your comment and thought you meant the other way round! Yeah, that sounds reasonable to me. cc @mthrok, thoughts?

Author:

On more thought, the current lack of autograd support and testing for complex might introduce BC-breaking changes.

Collaborator:

Oh sorry, I misread your comment and thought you meant the other way round! Yeah, that sounds reasonable to me. cc @mthrok, thoughts?
On more thought, the current lack of autograd support and testing for complex might introduce BC-breaking changes.

I am in favor of not duplicating the logic; however, if that introduces BC breakage for real-valued tensor input, then I think we can wait until autograd support arrives.

Contributor:

If it's R2R with complex insides, the choice of JAX/TF convention doesn't matter; you'll always get the same gradients in the end.

@vincentqb vincentqb marked this pull request as ready for review August 18, 2020 17:02
@anjali411 anjali411 marked this pull request as draft August 18, 2020 17:02
@anjali411 anjali411 marked this pull request as ready for review August 18, 2020 17:03
@vincentqb (Contributor):

@anjali411: TimeStretch wraps phase_vocoder as a torch.nn.Module. Have you tried running the backward pass with TimeStretch in the following cases?

  • before this pull request with a real tensor
  • after this pull request with a real tensor
  • after this pull request with a complex tensor

This could make for an interesting new test: not to check correctness of the result, but just whether autograd runs without errors.
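
The suggested smoke test might look roughly like this (assuming a PyTorch build with complex autograd support; it only checks that backward runs, not that the gradients are correct):

```python
import torch

spec = torch.randn(2, 5, 7, dtype=torch.cfloat, requires_grad=True)

# A tiny phase_vocoder-like computation on a complex input
mag = spec.abs()
phase_acc = torch.cumsum(spec.angle(), dim=-1)
stretched = torch.polar(mag, phase_acc)

# Does autograd run without errors?
stretched.abs().sum().backward()
```

The same pattern, applied to the real TimeStretch module before and after this PR, would cover the three cases listed above.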

mthrok pushed a commit to mthrok/audio that referenced this pull request Feb 26, 2021
mthrok pushed a commit to mthrok/audio that referenced this pull request Feb 26, 2021
Adding Zafar's changes from PR pytorch#758 to run flags + gitignore additions

mthrok commented Apr 1, 2021

To be replaced by #1410

@mthrok mthrok closed this Apr 1, 2021
mthrok added a commit that referenced this pull request Apr 2, 2021
1. `F.phase_vocoder` accepts Tensor with complex dtype.
    * The implementation has been updated from #758 so that both input types share the same code path, by internally converting the input Tensor to the complex dtype and performing all operations on it.
    * Adopted `torch.polar` for simpler Tensor generation from magnitude and angle.
2. Updated tests
    * librosa compatibility test for complex dtype and pseudo complex dtype
        * Extracted the output shape check test and moved it to functional so that it will be tested on all combinations of `{CPU | CUDA} x {complex64 | complex128}`
    * TorchScript compatibility test for `F.phase_vocoder` and `T.TimeStretch`.
    * batch consistency test for `T.TimeStretch`.
@mthrok mthrok modified the milestones: Complex Tensor Migration, v0.9 Apr 5, 2021
mpc001 pushed a commit to mpc001/audio that referenced this pull request Aug 4, 2023
6 participants